package com.bosssoft.platform.fasttcc.rpc.command;

import com.bosssoft.platform.fasttcc.support.OriginMethodInvoke;
import com.bosssoft.platform.fasttcc.TccTransaction;
import com.bosssoft.platform.fasttcc.Xid;
import com.jfireframework.baseutil.TRACEID;
import com.bosssoft.platform.fasttcc.TccTransactionManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Default {@link TccCommandHandler}: keeps one {@link Info} record per transaction {@link Xid} and
 * dispatches method-call, commit and rollback commands against the recorded transaction.
 * <p>
 * Records live in a {@link ConcurrentHashMap}; every operation that touches a transaction runs while
 * holding the monitor of that transaction's {@code Info}, so commands for the same xid are serialized.
 * <p>
 * NOTE(review): nothing in this class removes entries from {@code transactionMap}; confirm whether
 * clean-up of finished transactions is expected to happen elsewhere.
 */
public class DefaultTccCommandHandler implements TccCommandHandler
{

    /**
     * Per-transaction bookkeeping. Declared static so that it does not keep a hidden reference to the
     * enclosing handler. Instances double as the lock object for their transaction.
     */
    static class Info
    {
        /** The transaction this record belongs to. */
        TccTransaction transaction;
        /** Call ids already accepted for this transaction; only accessed while holding this record's monitor. */
        Set<String>    callIds    = new HashSet<>();
        /** Wall-clock time (milliseconds) at which this record was created. */
        long           createTime = System.currentTimeMillis();

        public Info(TccTransaction transaction)
        {
            this.transaction = transaction;
        }
    }

    private static final Logger                       LOGGER         = LoggerFactory.getLogger(DefaultTccCommandHandler.class);
    private              TccTransactionManager        manager;
    private final        ConcurrentHashMap<Xid, Info> transactionMap = new ConcurrentHashMap<>();

    /**
     * Sets the manager used to create participator transactions and to associate/de-associate them
     * with the current thread.
     *
     * @param manager the transaction manager to delegate to
     */
    @Override
    public void setTccTransactionManager(TccTransactionManager manager)
    {
        this.manager = manager;
    }

    /**
     * Runs the original method inside the transaction identified by the command's xid, creating and
     * registering a participator transaction when this xid is seen for the first time.
     *
     * @param command            carries the xid, the identifier of the calling node and the call id
     * @param originMethodInvoke the invocation to execute while the transaction is associated
     * @return whatever {@code originMethodInvoke.invoke()} returns
     * @throws IllegalStateException if the command comes from a node other than the one that propagated
     *                               the transaction, or if the call id was already accepted
     */
    @Override
    public Object processTccMethodCallCommand(TccMethodCallCommand command, OriginMethodInvoke originMethodInvoke)
    {
        String traceId = TRACEID.currentTraceId();
        Xid    xid     = command.getXid();
        Info   exist   = transactionMap.get(xid);
        if (exist == null)
        {
            TccTransaction tccTransaction = manager.newParticipatorTccTransaction(xid, command.getNodeIdentifier());
            Info           created        = new Info(tccTransaction);
            // putIfAbsent returns the record that won the race (or null when ours was stored), so no
            // second lookup is needed and only the winning thread reports the first request.
            Info           raced          = transactionMap.putIfAbsent(xid, created);
            if (raced == null)
            {
                exist = created;
                LOGGER.debug("traceId:{} 首次收到事务xid:{}的请求", traceId, xid);
            }
            else
            {
                exist = raced;
            }
        }
        synchronized (exist)
        {
            // Validate the caller before recording the call id; otherwise a rejected request would
            // leave its id behind and block a later legitimate call carrying the same id.
            if (!exist.transaction.getPropagatedBy().equals(command.getNodeIdentifier()))
            {
                throw crossedParent(command.getNodeIdentifier(), exist.transaction.getPropagatedBy());
            }
            if (!exist.callIds.add(command.getCallId()))
            {
                throw new IllegalStateException("操作:" + command.getCallId() + "已经执行过或正在执行中");
            }
            try
            {
                LOGGER.debug("traceId:{} 关联TCC事务到当前线程", traceId);
                manager.associateTccTransaction(exist.transaction);
                return originMethodInvoke.invoke();
            }
            finally
            {
                manager.deAssociateTccTransaction();
            }
        }
    }

    /**
     * Commits the transaction identified by the command's xid. Unknown xids are ignored, as are
     * transactions that are already marked for commit or finished.
     *
     * @param command carries the xid and the identifier of the requesting node
     * @throws IllegalStateException if the command comes from a node other than the one that propagated
     *                               the transaction, or if the transaction is marked for rollback
     */
    @Override
    public void processTccCommitCommand(TccCommitCommand command)
    {
        Info info = transactionMap.get(command.getXid());
        if (info == null)
        {
            return;
        }
        if (!info.transaction.getPropagatedBy().equals(command.getNodeIdentifier()))
        {
            throw crossedParent(command.getNodeIdentifier(), info.transaction.getPropagatedBy());
        }
        complete(info, true);
    }

    /**
     * Rolls back the transaction identified by the command's xid. Unknown xids are ignored, as are
     * transactions that are already marked for rollback or finished.
     *
     * @param command carries the xid and the identifier of the requesting node
     * @throws IllegalStateException if the command comes from a node other than the one that propagated
     *                               the transaction, or if the transaction is marked for commit
     */
    @Override
    public void processTccRollbackCommand(TccRollbackCommand command)
    {
        Info info = transactionMap.get(command.getXid());
        if (info == null)
        {
            return;
        }
        if (!info.transaction.getPropagatedBy().equals(command.getNodeIdentifier()))
        {
            throw crossedParent(command.getNodeIdentifier(), info.transaction.getPropagatedBy());
        }
        complete(info, false);
    }

    /**
     * Registers a transaction under its own xid, replacing any record previously stored for that xid.
     *
     * @param tccTransaction the transaction to register
     */
    @Override
    public void addTransaction(TccTransaction tccTransaction)
    {
        transactionMap.put(tccTransaction.getXid(), new Info(tccTransaction));
    }

    /**
     * Shared commit/rollback path: while holding the record's monitor and with the transaction
     * associated to the current thread, marks an ACTIVE transaction for the requested outcome and runs
     * its complete stage. Any state other than ACTIVE or the opposite mark is left untouched.
     *
     * @param info   record of the transaction to complete
     * @param commit {@code true} to commit, {@code false} to roll back
     * @throws IllegalStateException if the transaction is already marked for the opposite outcome
     */
    private void complete(Info info, boolean commit)
    {
        synchronized (info)
        {
            manager.associateTccTransaction(info.transaction);
            try
            {
                int state = info.transaction.getState();
                if (state == TccTransaction.ACTIVE)
                {
                    if (commit)
                    {
                        info.transaction.markForCommit();
                    }
                    else
                    {
                        info.transaction.markForRollback();
                    }
                    info.transaction.processCompleteStage();
                    return;
                }
                // This situation should not occur: the opposite outcome was already requested.
                boolean conflicting = commit ? state == TccTransaction.MARK_FOR_ROLLBACK : state == TccTransaction.MARK_FOR_COMMIT;
                if (conflicting)
                {
                    throw new IllegalStateException("transaction " + info.transaction.getXid()
                            + (commit ? " cannot be committed: it is already marked for rollback"
                                      : " cannot be rolled back: it is already marked for commit"));
                }
                // Already marked for the requested outcome, finished, or any other state: nothing to do.
            }
            finally
            {
                manager.deAssociateTccTransaction();
            }
        }
    }

    /**
     * Builds the exception reported when a command arrives from a node other than the one recorded
     * as the transaction's propagator.
     *
     * @param requested node identifier carried by the command
     * @param recorded  node identifier recorded on the transaction
     * @return the exception to throw
     */
    private static IllegalStateException crossedParent(Object requested, Object recorded)
    {
        return new IllegalStateException("上级节点出现交叉，不允许这种情况，交叉节点为:" + requested + "和" + recorded);
    }
}
